{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Tce3stUlHN0L"
   },
   "source": [
    "##### Copyright 2019 The TensorFlow Authors.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2022-06-03T20:06:58.138639Z",
     "iopub.status.busy": "2022-06-03T20:06:58.138353Z",
     "iopub.status.idle": "2022-06-03T20:06:58.143137Z",
     "shell.execute_reply": "2022-06-03T20:06:58.142300Z"
    },
    "id": "tuOe1ymfHZPu"
   },
   "outputs": [],
   "source": [
    "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GEe3i16tQPjo"
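   },
   "source": [
    "Note: The TensorFlow community translated these documents. Because community translations are best-effort, there is no guarantee that they are accurate or reflect the latest\n",
    "[official English documentation](https://tensorflow.google.cn/?hl=en). If you have suggestions for improving this translation, please submit a pull request to the\n",
    "[tensorflow/docs](https://github.com/tensorflow/docs) GitHub repository. To volunteer to write or review translations, join the\n",
    "[docs-zh-cn@tensorflow.org Google Group](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-zh-cn)."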
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xHxb-dlhMIzW"
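   },
   "source": [
    "## Overview\n",
    "\n",
    "This tutorial demonstrates multi-worker distributed training with a Keras model using the `tf.distribute.Strategy` API. With a strategy designed for multi-worker training, a Keras model written to run on a single worker can run on multiple workers with minimal code changes.\n",
    "\n",
    "The [Distributed training in TensorFlow](../../guide/distribute_strategy.ipynb) guide gives an overview of the distribution strategies TensorFlow supports, for readers who want a deeper understanding of the `tf.distribute.Strategy` API.\n"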
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MUXex9ctTuDB"
   },
   "source": [
    "## Setup\n",
    "\n",
    "First, set up TensorFlow and the necessary imports."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-03T20:06:58.147316Z",
     "iopub.status.busy": "2022-06-03T20:06:58.147011Z",
     "iopub.status.idle": "2022-06-03T20:07:30.624846Z",
     "shell.execute_reply": "2022-06-03T20:07:30.623814Z"
    },
    "id": "bnYxvfLD-LW-"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting tf-nightly\r\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Downloading tf_nightly-2.10.0.dev20220603-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (505.4 MB)\r\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting tb-nightly~=2.10.0.a\r\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  Downloading tb_nightly-2.10.0a20220603-py3-none-any.whl (5.8 MB)\r\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: astunparse>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.6.3)\r\n",
      "Requirement already satisfied: termcolor>=1.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.1.0)\r\n",
      "Requirement already satisfied: opt-einsum>=2.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (3.3.0)\r\n",
      "Requirement already satisfied: flatbuffers<2,>=1.12 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.12)\r\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting tf-estimator-nightly~=2.10.0.dev\r\n",
      "  Downloading tf_estimator_nightly-2.10.0.dev2022060308-py2.py3-none-any.whl (438 kB)\r\n",
      "Requirement already satisfied: libclang>=13.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (14.0.1)\r\n",
      "Requirement already satisfied: absl-py>=1.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.1.0)\r\n",
      "Requirement already satisfied: keras-preprocessing>=1.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.1.2)\r\n",
      "Requirement already satisfied: six>=1.12.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.16.0)\r\n",
      "Requirement already satisfied: wrapt>=1.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.14.1)\r\n",
      "Requirement already satisfied: protobuf<3.20,>=3.9.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (3.19.4)\r\n",
      "Requirement already satisfied: typing-extensions>=3.6.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (4.2.0)\r\n",
      "Requirement already satisfied: packaging in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (21.3)\r\n",
      "Requirement already satisfied: numpy>=1.20 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.23.0rc2)\r\n",
      "Requirement already satisfied: gast<=0.4.0,>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (0.4.0)\r\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting keras-nightly~=2.10.0.dev\r\n",
      "  Downloading keras_nightly-2.10.0.dev2022060307-py2.py3-none-any.whl (1.7 MB)\r\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: google-pasta>=0.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (0.2.0)\r\n",
      "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (1.46.3)\r\n",
      "Requirement already satisfied: setuptools in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (62.3.2)\r\n",
      "Requirement already satisfied: h5py>=2.9.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (3.7.0)\r\n",
      "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tf-nightly) (0.26.0)\r\n",
      "Requirement already satisfied: wheel<1.0,>=0.23.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from astunparse>=1.6.0->tf-nightly) (0.37.1)\r\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tb-nightly~=2.10.0.a->tf-nightly) (0.4.6)\r\n",
      "Requirement already satisfied: google-auth<3,>=1.6.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tb-nightly~=2.10.0.a->tf-nightly) (2.6.6)\r\n",
      "Requirement already satisfied: markdown>=2.6.8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tb-nightly~=2.10.0.a->tf-nightly) (3.3.7)\r\n",
      "Requirement already satisfied: werkzeug>=1.0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tb-nightly~=2.10.0.a->tf-nightly) (2.1.2)\r\n",
      "Requirement already satisfied: requests<3,>=2.21.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tb-nightly~=2.10.0.a->tf-nightly) (2.27.1)\r\n",
      "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tb-nightly~=2.10.0.a->tf-nightly) (1.8.1)\r\n",
      "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tb-nightly~=2.10.0.a->tf-nightly) (0.6.1)\r\n",
      "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from packaging->tf-nightly) (3.0.9)\r\n",
      "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tb-nightly~=2.10.0.a->tf-nightly) (5.2.0)\r\n",
      "Requirement already satisfied: pyasn1-modules>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tb-nightly~=2.10.0.a->tf-nightly) (0.2.8)\r\n",
      "Requirement already satisfied: rsa<5,>=3.1.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tb-nightly~=2.10.0.a->tf-nightly) (4.8)\r\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: requests-oauthlib>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tb-nightly~=2.10.0.a->tf-nightly) (1.3.1)\r\n",
      "Requirement already satisfied: importlib-metadata>=4.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown>=2.6.8->tb-nightly~=2.10.0.a->tf-nightly) (4.11.4)\r\n",
      "Requirement already satisfied: charset-normalizer~=2.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tb-nightly~=2.10.0.a->tf-nightly) (2.0.12)\r\n",
      "Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tb-nightly~=2.10.0.a->tf-nightly) (3.3)\r\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tb-nightly~=2.10.0.a->tf-nightly) (1.26.9)\r\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tb-nightly~=2.10.0.a->tf-nightly) (2022.5.18.1)\r\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tb-nightly~=2.10.0.a->tf-nightly) (3.8.0)\r\n",
      "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tb-nightly~=2.10.0.a->tf-nightly) (0.4.8)\r\n",
      "Requirement already satisfied: oauthlib>=3.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tb-nightly~=2.10.0.a->tf-nightly) (3.2.0)\r\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Installing collected packages: keras-nightly, tf-estimator-nightly, tb-nightly, tf-nightly\r\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Successfully installed keras-nightly-2.10.0.dev2022060307 tb-nightly-2.10.0a20220603 tf-estimator-nightly-2.10.0.dev2022060308 tf-nightly-2.10.0.dev20220603\r\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.base has been moved to tensorflow.python.trackable.base. The old module will be deleted in version 2.11.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.checkpoint_management has been moved to tensorflow.python.checkpoint.checkpoint_management. The old module will be deleted in version 2.9.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.resource has been moved to tensorflow.python.trackable.resource. The old module will be deleted in version 2.11.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.graph_view has been moved to tensorflow.python.checkpoint.graph_view. The old module will be deleted in version 2.11.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.python_state has been moved to tensorflow.python.trackable.python_state. The old module will be deleted in version 2.11.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.util has been moved to tensorflow.python.checkpoint.checkpoint. The old module will be deleted in version 2.11.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.saving.functional_saver has been moved to tensorflow.python.checkpoint.functional_saver. The old module will be deleted in version 2.11.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.saving.checkpoint_options has been moved to tensorflow.python.checkpoint.checkpoint_options. The old module will be deleted in version 2.11.\n"
     ]
    }
   ],
   "source": [
    "!pip install tf-nightly\n",
    "import tensorflow_datasets as tfds\n",
    "import tensorflow as tf\n",
    "tfds.disable_progress_bar()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hPBuZUNSZmrQ"
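   },
   "source": [
    "## Preparing the dataset\n",
    "\n",
    "Now, prepare the MNIST dataset from [TensorFlow Datasets](https://tensorflow.google.cn/datasets). The [MNIST dataset](http://yann.lecun.com/exdb/mnist/) comprises 60,000 training examples and 10,000 test examples of the handwritten digits 0-9, formatted as 28x28-pixel monochrome images."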
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-03T20:07:30.630669Z",
     "iopub.status.busy": "2022-06-03T20:07:30.629579Z",
     "iopub.status.idle": "2022-06-03T20:07:34.347314Z",
     "shell.execute_reply": "2022-06-03T20:07:34.346366Z"
    },
    "id": "dma_wUAxZqo2"
   },
   "outputs": [],
   "source": [
    "BUFFER_SIZE = 10000\n",
    "BATCH_SIZE = 64\n",
    "\n",
    "def make_datasets_unbatched():\n",
    "  # Scale MNIST pixel values from [0, 255] to [0., 1.]\n",
    "  def scale(image, label):\n",
    "    image = tf.cast(image, tf.float32)\n",
    "    image /= 255\n",
    "    return image, label\n",
    "\n",
    "  datasets, info = tfds.load(name='mnist',\n",
    "                            with_info=True,\n",
    "                            as_supervised=True)\n",
    "\n",
    "  return datasets['train'].map(scale).cache().shuffle(BUFFER_SIZE)\n",
    "\n",
    "train_datasets = make_datasets_unbatched().batch(BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "o87kcvu8GR4-"
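   },
   "source": [
    "## Build the Keras model\n",
    "Here, use the `tf.keras.Sequential` API to build and compile a simple convolutional neural network Keras model to train on the MNIST dataset.\n",
    "\n",
    "Note: For a more detailed walkthrough of building Keras models, see the [TensorFlow Keras guide](https://tensorflow.google.cn/guide/keras#sequential_model).\n"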
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-03T20:07:34.352018Z",
     "iopub.status.busy": "2022-06-03T20:07:34.351717Z",
     "iopub.status.idle": "2022-06-03T20:07:34.357616Z",
     "shell.execute_reply": "2022-06-03T20:07:34.356696Z"
    },
    "id": "aVPHl0SfG2v1"
   },
   "outputs": [],
   "source": [
    "def build_and_compile_cnn_model():\n",
    "  model = tf.keras.Sequential([\n",
    "      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),\n",
    "      tf.keras.layers.MaxPooling2D(),\n",
    "      tf.keras.layers.Flatten(),\n",
    "      tf.keras.layers.Dense(64, activation='relu'),\n",
    "      tf.keras.layers.Dense(10)\n",
    "  ])\n",
    "  model.compile(\n",
    "      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "      optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),\n",
    "      metrics=['accuracy'])\n",
    "  return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2UL3kisMO90X"
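   },
   "source": [
    "First, try training the model for a small number of epochs on a single worker and observe the results to make sure everything works. With enough training steps, the loss should drop and the accuracy should approach 1.0; the short run here (3 epochs of 5 steps each) is only a sanity check."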
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-03T20:07:34.361210Z",
     "iopub.status.busy": "2022-06-03T20:07:34.360926Z",
     "iopub.status.idle": "2022-06-03T20:07:36.638424Z",
     "shell.execute_reply": "2022-06-03T20:07:36.637411Z"
    },
    "id": "6Qe6iAf5O8iJ"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      "1/5 [=====>........................] - ETA: 8s - loss: 2.2997 - accuracy: 0.0625"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "5/5 [==============================] - 2s 5ms/step - loss: 2.3034 - accuracy: 0.0875\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2/3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      "1/5 [=====>........................] - ETA: 0s - loss: 2.3078 - accuracy: 0.0938"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "5/5 [==============================] - 0s 4ms/step - loss: 2.3089 - accuracy: 0.0844\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3/3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      "1/5 [=====>........................] - ETA: 0s - loss: 2.3156 - accuracy: 0.0938"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "5/5 [==============================] - 0s 4ms/step - loss: 2.3126 - accuracy: 0.0875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-06-03 20:07:36.608025: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n",
      "2022-06-03 20:07:36.610575: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7fc29c18fb20>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "single_worker_model = build_and_compile_cnn_model()\n",
    "single_worker_model.fit(x=train_datasets, epochs=3, steps_per_epoch=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8YFpxrcsZ2xG"
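   },
   "source": [
    "## Multi-worker configuration\n",
    "\n",
    "Now let's enter the world of multi-worker training. In TensorFlow, the `TF_CONFIG` environment variable is required for training on multiple machines, each of which may have a different role. `TF_CONFIG` specifies the cluster configuration for each worker that is part of the cluster.\n",
    "\n",
    "`TF_CONFIG` has two components: `cluster` and `task`. `cluster` provides information about the training cluster: it is a dict made up of different types of jobs, such as `worker`. In multi-worker training, one worker usually takes on a little more responsibility than the regular workers, such as saving checkpoints and writing summary files for TensorBoard. This worker is called the \"chief\" worker, and by convention the `worker` with `index` 0 is appointed as the chief (in fact, this is how `tf.distribute.Strategy` is implemented). `task`, on the other hand, provides information about the current task.\n",
    "\n",
    "In this example, the task `type` is set to `\"worker\"` and the task `index` to `0`. This means the machine with this setting is the first worker, which will be appointed as the chief worker and do more work than the others. Note that the other machines also need the `TF_CONFIG` environment variable set, with the same `cluster` dict but a different task `type` or `index`, depending on each machine's role.\n",
    "\n",
    "For illustration purposes, this tutorial shows how to set a `TF_CONFIG` with 2 workers on `localhost`. In practice, users would create multiple workers on external IP addresses/ports and set `TF_CONFIG` on each worker appropriately.\n",
    "\n",
    "Warning: Do not execute the following code in Colab. TensorFlow's runtime will try to create a gRPC server at the specified IP address and port, which will likely fail.\n",
    "\n",
    "```\n",
    "import json\n",
    "import os\n",
    "\n",
    "os.environ['TF_CONFIG'] = json.dumps({\n",
    "    'cluster': {\n",
    "        'worker': [\"localhost:12345\", \"localhost:23456\"]\n",
    "    },\n",
    "    'task': {'type': 'worker', 'index': 0}\n",
    "})\n",
    "```\n"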
   ]
  },
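  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The per-machine setup described above can be sketched as follows. This is a minimal illustration, not part of the original tutorial: the helper names `make_tf_config` and `is_chief` are hypothetical, the addresses are the `localhost` ones used above, and the chief check simply encodes the stated convention that worker index 0 is the chief. The code only builds JSON strings and does not start any TensorFlow server.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "def make_tf_config(worker_addresses, task_index):\n",
    "    # Hypothetical helper: every worker shares the same 'cluster' dict\n",
    "    # and differs only in its own 'task' entry.\n",
    "    return json.dumps({\n",
    "        'cluster': {'worker': list(worker_addresses)},\n",
    "        'task': {'type': 'worker', 'index': task_index},\n",
    "    })\n",
    "\n",
    "def is_chief(tf_config_json):\n",
    "    # Convention from the text: the worker with index 0 acts as chief.\n",
    "    task = json.loads(tf_config_json)['task']\n",
    "    return task['type'] == 'worker' and task['index'] == 0\n",
    "\n",
    "workers = ['localhost:12345', 'localhost:23456']\n",
    "configs = [make_tf_config(workers, i) for i in range(len(workers))]\n",
    "for i, config in enumerate(configs):\n",
    "    print(i, is_chief(config), config)\n",
    "```\n",
    "\n",
    "On each machine, the string for that machine's index would be assigned to `os.environ['TF_CONFIG']` before the strategy is created."
   ]
  },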
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "P94PrIW_kSCE"
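   },
   "source": [
    "Note that although the learning rate is fixed in this example, in general it may be necessary to adjust the learning rate based on the global batch size."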
   ]
  },
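  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One common heuristic, which this tutorial does not prescribe, is to scale the learning rate linearly with the global batch size. The sketch below assumes that rule; `scaled_learning_rate` is a hypothetical helper, and the base values are the `BATCH_SIZE` and SGD learning rate from the single-worker run above. Treat the result as a starting point for tuning rather than a guaranteed setting.\n",
    "\n",
    "```python\n",
    "BASE_BATCH_SIZE = 64         # BATCH_SIZE from the single-worker run\n",
    "BASE_LEARNING_RATE = 0.001   # SGD learning rate from the model above\n",
    "\n",
    "def scaled_learning_rate(num_workers):\n",
    "    # Assumption: linear scaling rule. Each worker keeps the same\n",
    "    # per-worker batch, so the global batch grows with the worker count.\n",
    "    global_batch_size = BASE_BATCH_SIZE * num_workers\n",
    "    return BASE_LEARNING_RATE * global_batch_size / BASE_BATCH_SIZE\n",
    "\n",
    "for n in (1, 2, 4):\n",
    "    print(n, BASE_BATCH_SIZE * n, scaled_learning_rate(n))\n",
    "```"
   ]
  },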
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UhNtHfuxCGVy"
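   },
   "source": [
    "## Choose the right strategy\n",
    "\n",
    "In TensorFlow, distributed training comes in two forms: synchronous training, where training steps are synchronized across workers and replicas, and asynchronous training, where training steps are not strictly synchronized.\n",
    "\n",
    "`MultiWorkerMirroredStrategy` is the recommended strategy for synchronous multi-worker training and is demonstrated in this guide.\n",
    "\n",
    "To train the model, use an instance of `tf.distribute.experimental.MultiWorkerMirroredStrategy`. `MultiWorkerMirroredStrategy` creates copies of all variables in the model's layers on each device across all workers. It uses `CollectiveOps`, a TensorFlow op for collective communication, to aggregate gradients and keep the variables in sync. The [`tf.distribute.Strategy` guide](../../guide/distribute_strategy.ipynb) has more details about this strategy."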
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-03T20:07:36.642672Z",
     "iopub.status.busy": "2022-06-03T20:07:36.642388Z",
     "iopub.status.idle": "2022-06-03T20:07:37.553198Z",
     "shell.execute_reply": "2022-06-03T20:07:37.552328Z"
    },
    "id": "1uFSHCJXMrQ-"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_46979/349189047.py:1: _CollectiveAllReduceStrategyExperimental.__init__ (from tensorflow.python.distribute.collective_all_reduce_strategy) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "use distribute.MultiWorkerMirroredStrategy instead\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_46979/349189047.py:1: _CollectiveAllReduceStrategyExperimental.__init__ (from tensorflow.python.distribute.collective_all_reduce_strategy) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "use distribute.MultiWorkerMirroredStrategy instead\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3'), communication = CommunicationImplementation.AUTO\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3'), communication = CommunicationImplementation.AUTO\n"
     ]
    }
   ],
   "source": [
    "strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "N0iv7SyyAohc"
   },
   "source": [
    "Note: `TF_CONFIG` is parsed and TensorFlow's gRPC servers are started when `MultiWorkerMirroredStrategy.__init__()` is called, so the `TF_CONFIG` environment variable must be set before a `tf.distribute.Strategy` instance is created."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FMy2VM4Akzpr"
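   },
   "source": [
    "`MultiWorkerMirroredStrategy` provides multiple implementations via the [`CollectiveCommunication`](https://github.com/tensorflow/tensorflow/blob/a385a286a930601211d78530734368ccb415bee4/tensorflow/python/distribute/cross_device_ops.py#L928) parameter. `RING` implements ring-based collectives using gRPC as the cross-host communication layer. `NCCL` uses [Nvidia's NCCL](https://developer.nvidia.com/nccl) to implement collectives. `AUTO` defers the choice to the runtime. The best choice of collective implementation depends on the number and kind of GPUs and on the network interconnect in the cluster."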
   ]
  },
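  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea of a collective concrete: an all-reduce leaves every replica holding the same aggregated value. The pure-Python sketch below averages per-replica gradients to illustrate that idea only. It is not TensorFlow's `RING` or `NCCL` implementation, and `all_reduce_mean` is a hypothetical name.\n",
    "\n",
    "```python\n",
    "def all_reduce_mean(per_replica_grads):\n",
    "    # Toy all-reduce: element-wise mean over replicas, after which\n",
    "    # every replica receives an identical copy of the result.\n",
    "    num_replicas = len(per_replica_grads)\n",
    "    mean = [sum(values) / num_replicas for values in zip(*per_replica_grads)]\n",
    "    return [list(mean) for _ in range(num_replicas)]\n",
    "\n",
    "# Two replicas, each holding gradients for the same three variables.\n",
    "grads = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]\n",
    "synced = all_reduce_mean(grads)\n",
    "print(synced)\n",
    "```\n",
    "\n",
    "Because every replica applies the same averaged gradient, the variable copies stay in sync after each step."
   ]
  },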
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "H47DDcOgfzm7"
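   },
   "source": [
    "## Train the model with MultiWorkerMirroredStrategy\n",
    "\n",
    "With the `tf.distribute.Strategy` API integrated into `tf.keras`, the only change needed to distribute training to multiple workers is to enclose the model building and the `model.compile()` call inside `strategy.scope()`. The distribution strategy's scope dictates how and where the variables are created; in the case of `MultiWorkerMirroredStrategy`, the variables created are `MirroredVariable`s, and they are replicated on each of the workers.\n",
    "\n",
    "Note: In this Colab, the following code runs with the expected result, but because `TF_CONFIG` is not set, this is effectively single-worker training. Once you set `TF_CONFIG` in your own example, you can expect a speed-up from training on multiple machines."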
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-03T20:07:37.557369Z",
     "iopub.status.busy": "2022-06-03T20:07:37.557025Z",
     "iopub.status.idle": "2022-06-03T20:07:46.692260Z",
     "shell.execute_reply": "2022-06-03T20:07:46.691376Z"
    },
    "id": "BcsuBYrpgnlS"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-06-03 20:07:38.409119: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      "1/5 [=====>........................] - ETA: 30s - loss: 2.3111 - accuracy: 0.0312"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "5/5 [==============================] - ETA: 0s - loss: 2.3121 - accuracy: 0.0656 "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "5/5 [==============================] - 8s 15ms/step - loss: 2.3121 - accuracy: 0.0656\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2/3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      "1/5 [=====>........................] - ETA: 0s - loss: 2.3031 - accuracy: 0.1016"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "5/5 [==============================] - ETA: 0s - loss: 2.3034 - accuracy: 0.0922"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "5/5 [==============================] - 0s 14ms/step - loss: 2.3034 - accuracy: 0.0922\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3/3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      "1/5 [=====>........................] - ETA: 0s - loss: 2.3033 - accuracy: 0.1172"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "5/5 [==============================] - ETA: 0s - loss: 2.3047 - accuracy: 0.1047"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "5/5 [==============================] - 0s 14ms/step - loss: 2.3047 - accuracy: 0.1047\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-06-03 20:07:46.666346: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n",
      "2022-06-03 20:07:46.669124: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7fc2842d79d0>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "NUM_WORKERS = 2\n",
    "# 由于 `tf.data.Dataset.batch` 需要全局的批处理大小，\n",
    "# 因此此处的批处理大小按工作器数量增加。\n",
    "# 以前我们使用64，现在变成128。\n",
    "GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS\n",
    "\n",
    "# 创建数据集需要在 MultiWorkerMirroredStrategy 对象\n",
    "# 实例化后。\n",
    "train_datasets = make_datasets_unbatched().batch(GLOBAL_BATCH_SIZE)\n",
    "with strategy.scope():\n",
    "  # 模型的建立/编译需要在 `strategy.scope()` 内部。\n",
    "  multi_worker_model = build_and_compile_cnn_model()\n",
    "\n",
    "# Keras 的 `model.fit()` 以特定的时期数和每时期的步数训练模型。\n",
    "# 注意此处的数量仅用于演示目的，并不足以产生高质量的模型。\n",
    "multi_worker_model.fit(x=train_datasets, epochs=3, steps_per_epoch=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Rr14Vl9GR4zq"
   },
   "source": [
    "### 数据集分片和批（batch）大小\n",
    "\n",
    "在多工作器训练中，需要将数据分片为多个部分，以确保融合和性能。 但是，请注意，在上面的代码片段中，数据集直接发送到`model.fit（）`，而无需分片； 这是因为`tf.distribute.Strategy` API在多工作器训练中会自动处理数据集分片。\n",
    "\n",
    "如果您喜欢手动分片进行训练，则可以通过`tf.data.experimental.DistributeOptions` API关闭自动分片。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-03T20:07:46.696228Z",
     "iopub.status.busy": "2022-06-03T20:07:46.695952Z",
     "iopub.status.idle": "2022-06-03T20:07:46.701935Z",
     "shell.execute_reply": "2022-06-03T20:07:46.701132Z"
    },
    "id": "JxEtdh1vH-TF"
   },
   "outputs": [],
   "source": [
    "options = tf.data.Options()\n",
    "options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF\n",
    "train_datasets_no_auto_shard = train_datasets.with_options(options)"
   ]
  },
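  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With automatic sharding turned off, each worker has to select its own slice of the data. The sketch below illustrates, in plain Python, the index-based rule that `tf.data.Dataset.shard(num_shards, index)` applies: an element is kept when its position modulo `num_shards` equals `index`. The `shard` function here is a hypothetical stand-in written for illustration, not the TensorFlow implementation, and how each worker obtains its own index is left as an assumption.\n",
    "\n",
    "```python\n",
    "def shard(elements, num_shards, index):\n",
    "    # Keep every element whose position modulo `num_shards` equals `index`.\n",
    "    return [x for i, x in enumerate(elements) if i % num_shards == index]\n",
    "\n",
    "\n",
    "data = list(range(10))\n",
    "worker_0 = shard(data, num_shards=2, index=0)\n",
    "worker_1 = shard(data, num_shards=2, index=1)\n",
    "print(worker_0)\n",
    "print(worker_1)\n",
    "```\n",
    "\n",
    "The two shards are disjoint and together cover the whole dataset, which is the property manual sharding needs to preserve."
   ]
  },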
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NBCtYvmCH-7g"
   },
   "source": [
    "要注意的另一件事是 `datasets` 的批处理大小。 在上面的代码片段中，我们使用 `GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS` ，这是单个工作器的大小的 `NUM_WORKERS` 倍，因为每个工作器的有效批量大小是全局批量大小（参数从 `tf.data.Dataset.batch()` 传入）除以工作器的数量，通过此更改，我们使每个工作器的批处理大小与以前相同。"
   ]
  },
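  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The batch-size arithmetic above can be sketched in plain Python. This is only an illustration: `per_worker_batch_size` and `PER_WORKER_BATCH_SIZE` are hypothetical names rather than TensorFlow APIs, and the sketch assumes the global batch is split evenly across workers.\n",
    "\n",
    "```python\n",
    "NUM_WORKERS = 2\n",
    "PER_WORKER_BATCH_SIZE = 64  # batch size used for a single worker\n",
    "\n",
    "# The value passed to `tf.data.Dataset.batch()` is the global batch size.\n",
    "GLOBAL_BATCH_SIZE = PER_WORKER_BATCH_SIZE * NUM_WORKERS\n",
    "\n",
    "\n",
    "def per_worker_batch_size(global_batch_size, num_workers):\n",
    "    # Hypothetical helper: assumes an even split across workers.\n",
    "    if global_batch_size % num_workers != 0:\n",
    "        raise ValueError('global batch size must be divisible by num_workers')\n",
    "    return global_batch_size // num_workers\n",
    "\n",
    "\n",
    "print(GLOBAL_BATCH_SIZE)\n",
    "print(per_worker_batch_size(GLOBAL_BATCH_SIZE, NUM_WORKERS))\n",
    "```\n",
    "\n",
    "Scaling the global batch size by the number of workers keeps the amount of data each worker processes per step unchanged."
   ]
  },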
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XVk4ftYx6JAO"
   },
   "source": [
    "## 性能\n",
    "\n",
    "现在，您已经有了一个Keras模型，该模型全部通过 `MultiWorkerMirroredStrategy` 运行在多个工作器中。 您可以尝试以下技术来调整多工作器训练的效果。\n",
    "\n",
    "*   `MultiWorkerMirroredStrategy` 提供了多个[集体通信实现][collective communication implementations](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/distribute/cross_device_ops.py).  `RING` 使用gRPC作为跨主机通信层实现基于环的集合。  `NCCL` 使用 [Nvidia's NCCL](https://developer.nvidia.com/nccl) 来实现集合。  `AUTO` 将推迟到运行时选择。集体实施的最佳选择取决于GPU的数量和种类以及集群中的网络互连。 要覆盖自动选择，请为 `MultiWorkerMirroredStrategy` 的构造函数的 `communication` 参数指定一个有效值，例如： `communication=tf.distribute.experimental.CollectiveCommunication.NCCL`.\n",
    "*    如果可能的话，将变量强制转换为 `tf.float`。ResNet 的官方模型包括如何完成此操作的[示例](https://github.com/tensorflow/models/blob/8367cf6dabe11adf7628541706b660821f397dce/official/resnet/resnet_model.py#L466)。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "97WhAu8uKw3j"
   },
   "source": [
    "## 容错能力\n",
    "\n",
    "在同步训练中，如果其中一个工作器出现故障并且不存在故障恢复机制，则集群将失败。 在工作器退出或不稳定的情况下，将 Keras 与 `tf.distribute.Strategy` 一起使用会具有容错的优势。 我们通过在您选择的分布式文件系统中保留训练状态来做到这一点，以便在重新启动先前失败或被抢占的实例后，将恢复训练状态。\n",
    "\n",
    "由于所有工作器在训练 epochs 和 steps 方面保持同步，因此其他工作器将需要等待失败或被抢占的工作器重新启动才能继续。\n",
    "\n",
    "### ModelCheckpoint 回调\n",
    "\n",
    "要在多工作器训练中利用容错功能，请在调用 `tf.keras.Model.fit()` 时提供一个 `tf.keras.callbacks.ModelCheckpoint` 实例。 回调会将检查点和训练状态存储在与 `ModelCheckpoint` 的 `filepath` 参数相对应的目录中。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-03T20:07:46.706141Z",
     "iopub.status.busy": "2022-06-03T20:07:46.705611Z",
     "iopub.status.idle": "2022-06-03T20:07:54.413366Z",
     "shell.execute_reply": "2022-06-03T20:07:54.412424Z"
    },
    "id": "xIY9vKnUU82o"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-06-03 20:07:46.828733: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      "1/5 [=====>........................] - ETA: 20s - loss: 2.3033 - accuracy: 0.1484"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "5/5 [==============================] - ETA: 0s - loss: 2.3035 - accuracy: 0.1531 "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.trackable_utils has been moved to tensorflow.python.trackable.trackable_utils. The old module will be deleted in version 2.11.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.trackable_utils has been moved to tensorflow.python.trackable.trackable_utils. The old module will be deleted in version 2.11.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "5/5 [==============================] - 6s 202ms/step - loss: 2.3035 - accuracy: 0.1531\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2/3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      "1/5 [=====>........................] - ETA: 0s - loss: 2.3021 - accuracy: 0.1250"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "5/5 [==============================] - ETA: 0s - loss: 2.3003 - accuracy: 0.1656"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "5/5 [==============================] - 1s 166ms/step - loss: 2.3003 - accuracy: 0.1656\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3/3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      "1/5 [=====>........................] - ETA: 0s - loss: 2.2945 - accuracy: 0.1250"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "5/5 [==============================] - ETA: 0s - loss: 2.2974 - accuracy: 0.1406"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r\n",
      "5/5 [==============================] - 1s 167ms/step - loss: 2.2974 - accuracy: 0.1406\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-06-03 20:07:54.384936: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n",
      "2022-06-03 20:07:54.388366: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7fc2747485b0>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 将 `filepath` 参数替换为在文件系统中所有工作器都能访问的路径。\n",
    "callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath='/tmp/keras-ckpt')]\n",
    "with strategy.scope():\n",
    "  multi_worker_model = build_and_compile_cnn_model()\n",
    "multi_worker_model.fit(x=train_datasets,\n",
    "                       epochs=3,\n",
    "                       steps_per_epoch=5,\n",
    "                       callbacks=callbacks)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ii6VmEdOjkZr"
   },
   "source": [
    "如果某个工作线程被抢占，则整个集群将暂停，直到重新启动被抢占的工作线程为止。工作器重新加入集群后，其他工作器也将重新启动。 现在，每个工作器都将读取先前保存的检查点文件，并获取其以前的状态，从而使群集能够恢复同步，然后继续训练。\n",
    "\n",
    "如果检查包含在`ModelCheckpoint` 中指定的 `filepath` 的目录，则可能会注意到一些临时生成的检查点文件。 这些文件是恢复以前丢失的实例所必需的，并且在成功退出多工作器训练后，这些文件将在 `tf.keras.Model.fit()` 的末尾被库删除。"
   ]
  },
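  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The save-and-resume cycle described above can be illustrated without TensorFlow. The sketch below is a toy simulation, not how `ModelCheckpoint` is implemented: `train`, the `state.json` file name, and the `fail_after` parameter are all hypothetical names introduced for illustration, and a single counter stands in for the real training state.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "\n",
    "def train(ckpt_path, total_epochs, fail_after=None):\n",
    "    # Resume from the saved state if a checkpoint exists, else start at epoch 0.\n",
    "    epoch = 0\n",
    "    if os.path.exists(ckpt_path):\n",
    "        with open(ckpt_path) as f:\n",
    "            epoch = json.load(f)['epoch']\n",
    "    while epoch < total_epochs:\n",
    "        if fail_after is not None and epoch == fail_after:\n",
    "            raise RuntimeError('simulated preemption')\n",
    "        epoch += 1  # stand-in for one epoch of training\n",
    "        with open(ckpt_path, 'w') as f:\n",
    "            json.dump({'epoch': epoch}, f)\n",
    "    return epoch\n",
    "\n",
    "\n",
    "ckpt_dir = tempfile.mkdtemp()\n",
    "ckpt_path = os.path.join(ckpt_dir, 'state.json')\n",
    "try:\n",
    "    train(ckpt_path, total_epochs=3, fail_after=2)\n",
    "except RuntimeError:\n",
    "    pass  # the worker was preempted after finishing two epochs\n",
    "\n",
    "with open(ckpt_path) as f:\n",
    "    saved_epoch = json.load(f)['epoch']\n",
    "final_epoch = train(ckpt_path, total_epochs=3)\n",
    "print(saved_epoch, final_epoch)\n",
    "```\n",
    "\n",
    "The second call to `train` does not start over: it reads the state written before the simulated preemption and runs only the remaining epoch."
   ]
  },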
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ega2hdOQEmy_"
   },
   "source": [
    "## 您可以查阅\n",
    "1. [Distributed Training in TensorFlow](https://tensorflow.google.cn/guide/distribute_strategy) 该指南概述了可用的分布式策略。\n",
    "2. [ResNet50](https://github.com/tensorflow/models/blob/master/official/resnet/imagenet_main.py) 官方模型，该模型可以使用 `MirroredStrategy` 或 `MultiWorkerMirroredStrategy` 进行训练"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "multi_worker_with_keras.ipynb",
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
